package cn.xuqiudong.rpc.spring.reference;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import cn.xuqiudong.rpc.spring.annotation.XqdReference;

import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

/**
 * Scans every bean definition for fields annotated with {@link XqdReference} and registers,
 * for each such field, a {@link XqdBeanFactory} bean definition named after the field and
 * configured with the field's declared type as its interface class.
 *
 * @author Vic.xu
 * @date 2022-02-28 17:57
 */
public class XqdSpringReferenceBeanProcessor implements BeanFactoryPostProcessor, BeanClassLoaderAware, ApplicationContextAware {

    private static final Logger logger = LoggerFactory.getLogger(XqdSpringReferenceBeanProcessor.class);

    /**
     * Class loader used to resolve bean class names; injected via {@link BeanClassLoaderAware}.
     * May remain {@code null} when the processor is registered without Aware callbacks, in which
     * case {@link ClassUtils#resolveClassName(String, ClassLoader)} uses the default class loader.
     */
    private ClassLoader classLoader;

    /**
     * Injected via {@link ApplicationContextAware}. Kept for compatibility; existence checks use
     * the bean factory handed to {@link #postProcessBeanFactory} so that this processor also works
     * when registered manually (where this field would stay {@code null}).
     */
    private ApplicationContext applicationContext;

    /**
     * Runs before any regular bean is instantiated. Collects all {@link XqdReference}-annotated
     * fields of the defined beans and registers one {@link XqdBeanFactory} definition per field
     * name, skipping names that are already known to the bean factory.
     *
     * @param beanFactory the bean factory of the application context
     * @throws BeansException if a bean definition cannot be read or registered
     * @throws IllegalStateException if references were found but {@code beanFactory} is not a
     *                               {@link BeanDefinitionRegistry}
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        Map<String, Class<?>> referencedTypes = collectReferencedTypes(beanFactory);
        if (referencedTypes.isEmpty()) {
            return;
        }
        if (!(beanFactory instanceof BeanDefinitionRegistry)) {
            throw new IllegalStateException("XqdSpringReferenceBeanProcessor requires a BeanDefinitionRegistry"
                    + " to register XqdReference beans, but got " + beanFactory.getClass().getName());
        }
        BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
        referencedTypes.forEach((beanName, interfaceClass) -> {
            // beanFactory.containsBean is what ApplicationContext.containsBean delegates to,
            // but it does not depend on the Aware callback having run.
            if (beanFactory.containsBean(beanName)) {
                logger.info("{} 已经注册到spring上下文", beanName);
                return;
            }
            registry.registerBeanDefinition(beanName, buildReferenceDefinition(interfaceClass));
            logger.info("成功注册 XqdReference bean：{}到spring", beanName);
        });
    }

    /**
     * Walks all bean definitions that declare a bean class name and gathers their
     * {@link XqdReference} fields (including inherited ones, as visited by
     * {@link ReflectionUtils#doWithFields}).
     *
     * @param beanFactory the factory whose definitions are scanned
     * @return field name to referenced interface type; later fields with the same name win
     */
    private Map<String, Class<?>> collectReferencedTypes(ConfigurableListableBeanFactory beanFactory) {
        Map<String, Class<?>> referencedTypes = new HashMap<>();
        for (String beanDefinitionName : beanFactory.getBeanDefinitionNames()) {
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanDefinitionName);
            String beanClassName = beanDefinition.getBeanClassName();
            // Definitions without a bean class name cannot be introspected here.
            if (beanClassName == null) {
                continue;
            }
            Class<?> clazz = ClassUtils.resolveClassName(beanClassName, this.classLoader);
            ReflectionUtils.doWithFields(clazz, field -> collectReference(field, referencedTypes));
        }
        return referencedTypes;
    }

    /**
     * Records {@code field} if it carries {@link XqdReference}. The bean name is the field name,
     * so two differently-typed fields sharing a name collide; that case is logged because only
     * the last one seen will be registered.
     */
    private static void collectReference(Field field, Map<String, Class<?>> referencedTypes) {
        if (AnnotationUtils.getAnnotation(field, XqdReference.class) == null) {
            return;
        }
        Class<?> interfaceClass = field.getType();
        Class<?> previous = referencedTypes.put(field.getName(), interfaceClass);
        if (previous != null && previous != interfaceClass) {
            logger.warn("XqdReference bean name '{}' is used for both {} and {}; only {} will be registered",
                    field.getName(), previous.getName(), interfaceClass.getName(), interfaceClass.getName());
        }
    }

    /**
     * Builds the {@link XqdBeanFactory} definition for one referenced interface: sets its init
     * method and passes the interface class as a property.
     */
    private static BeanDefinition buildReferenceDefinition(Class<?> interfaceClass) {
        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(XqdBeanFactory.class);
        builder.setInitMethodName(XqdBeanFactory.INIT_METHOD_NAME);
        builder.addPropertyValue(XqdBeanFactory.INTERFACE_CLASS_FIELD_NAME, interfaceClass);
        return builder.getBeanDefinition();
    }

    @Override
    public void setBeanClassLoader(ClassLoader classLoader) {
        this.classLoader = classLoader;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }
}
